{
    "cells": [
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Training"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/training.ipynb)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Install chemprop from GitHub when running in Google Colab.\n",
                "# The install only runs when COLAB_RELEASE_TAG is set and chemprop is not already importable.\n",
                "import os\n",
                "\n",
                "if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
                "    try:\n",
                "        import chemprop\n",
                "    except ImportError:\n",
                "        !git clone https://github.com/chemprop/chemprop.git\n",
                "        %cd chemprop\n",
                "        # `%pip` (rather than `!pip`) installs into the environment of the running kernel\n",
                "        %pip install .\n",
                "        # later cells locate the example data relative to the `examples` directory\n",
                "        %cd examples"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Import packages"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 1,
            "metadata": {},
            "outputs": [],
            "source": [
                "from pathlib import Path\n",
                "\n",
                "from lightning import pytorch as pl\n",
                "from lightning.pytorch.callbacks import ModelCheckpoint\n",
                "import pandas as pd\n",
                "from sklearn.model_selection import KFold\n",
                "\n",
                "from chemprop import data, featurizers, models, nn"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Change data inputs here"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 2,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Repository root; assumes the notebook is run from the `examples` directory\n",
                "chemprop_dir = Path.cwd().parent\n",
                "\n",
                "# Path to your data .csv file\n",
                "input_path = chemprop_dir.joinpath(\"tests\", \"data\", \"regression\", \"mol\", \"mol.csv\")\n",
                "\n",
                "# Number of workers for the dataloader; 0 means the main process loads the data\n",
                "num_workers = 0\n",
                "\n",
                "# Name of the column containing SMILES strings\n",
                "smiles_column = \"smiles\"\n",
                "\n",
                "# Names of the columns containing targets\n",
                "target_columns = [\"lipo\"]"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Load data"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>smiles</th>\n",
                            "      <th>lipo</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>0</th>\n",
                            "      <td>Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14</td>\n",
                            "      <td>3.54</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>1</th>\n",
                            "      <td>COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...</td>\n",
                            "      <td>-1.18</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>2</th>\n",
                            "      <td>COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl</td>\n",
                            "      <td>3.69</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>3</th>\n",
                            "      <td>OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...</td>\n",
                            "      <td>3.37</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>4</th>\n",
                            "      <td>Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...</td>\n",
                            "      <td>3.10</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>...</th>\n",
                            "      <td>...</td>\n",
                            "      <td>...</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>95</th>\n",
                            "      <td>CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...</td>\n",
                            "      <td>2.20</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>96</th>\n",
                            "      <td>CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...</td>\n",
                            "      <td>2.04</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>97</th>\n",
                            "      <td>CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...</td>\n",
                            "      <td>4.49</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>98</th>\n",
                            "      <td>COc1ccc(Cc2c(N)n[nH]c2N)cc1</td>\n",
                            "      <td>0.20</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>99</th>\n",
                            "      <td>CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...</td>\n",
                            "      <td>2.00</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "<p>100 rows × 2 columns</p>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "                                               smiles  lipo\n",
                            "0             Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14  3.54\n",
                            "1   COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... -1.18\n",
                            "2              COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl  3.69\n",
                            "3   OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...  3.37\n",
                            "4   Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...  3.10\n",
                            "..                                                ...   ...\n",
                            "95  CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...  2.20\n",
                            "96  CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...  2.04\n",
                            "97  CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...  4.49\n",
                            "98                        COc1ccc(Cc2c(N)n[nH]c2N)cc1  0.20\n",
                            "99  CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...  2.00\n",
                            "\n",
                            "[100 rows x 2 columns]"
                        ]
                    },
                    "execution_count": 3,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "df_input = pd.read_csv(input_path)\n",
                "df_input"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Get SMILES and targets"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 4,
            "metadata": {},
            "outputs": [],
            "source": [
                "# `to_numpy()` is the pandas-recommended accessor and always returns a NumPy array\n",
                "smis = df_input.loc[:, smiles_column].to_numpy()  # 1-D array of SMILES strings\n",
                "ys = df_input.loc[:, target_columns].to_numpy()  # 2-D array with one column per target"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 5,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "array(['Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14',\n",
                            "       'COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23',\n",
                            "       'COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl',\n",
                            "       'OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3',\n",
                            "       'Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1'],\n",
                            "      dtype=object)"
                        ]
                    },
                    "execution_count": 5,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "smis[:5] # show first 5 SMILES strings"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 6,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "array([[ 3.54],\n",
                            "       [-1.18],\n",
                            "       [ 3.69],\n",
                            "       [ 3.37],\n",
                            "       [ 3.1 ]])"
                        ]
                    },
                    "execution_count": 6,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "ys[:5] # show first 5 targets"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Get molecule datapoints"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 7,
            "metadata": {},
            "outputs": [],
            "source": [
                "all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Perform data splitting for training, validation, and testing"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 8,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "['SCAFFOLD_BALANCED',\n",
                            " 'RANDOM_WITH_REPEATED_SMILES',\n",
                            " 'RANDOM',\n",
                            " 'KENNARD_STONE',\n",
                            " 'KMEANS']"
                        ]
                    },
                    "execution_count": 8,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "# available split types\n",
                "list(data.SplitType.keys())"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Chemprop's `make_split_indices` function will always return a two- (if no validation) or three-length tuple.\n",
                "Each member is a list of length `num_replicates`.\n",
                "The inner lists then contain the actual indices for splitting.\n",
                "\n",
                "The type signature for this return type is `tuple[list[list[int]], ...]`."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 9,
            "metadata": {},
            "outputs": [],
            "source": [
                "# RDKit Mol objects are used for structure-based splits\n",
                "mols = [datapoint.mol for datapoint in all_data]\n",
                "\n",
                "# unpack the returned tuple into three separate lists of indices\n",
                "train_indices, val_indices, test_indices = data.make_split_indices(\n",
                "    mols, \"random\", (0.8, 0.1, 0.1)\n",
                ")\n",
                "\n",
                "# turn the index lists into the corresponding lists of datapoints\n",
                "train_data, val_data, test_data = data.split_data_by_indices(\n",
                "    all_data, train_indices, val_indices, test_indices\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Chemprop's splitting function implements our preferred method of data splitting, which is random replication.\n",
                "It's also possible to add your own custom cross-validation splitter, such as one of those implemented in scikit-learn, as long as you get the data into the same `tuple[list[list[int]], ...]` data format with something like this:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 10,
            "metadata": {},
            "outputs": [],
            "source": [
                "# `KFold` comes from the imports cell at the top of the notebook\n",
                "k_splits = KFold(n_splits=5)\n",
                "k_train_indices, k_val_indices, k_test_indices = [], [], []\n",
                "\n",
                "# each fold is a (train indices, test indices) pair; one entry is appended per fold\n",
                "for fold_train_idxs, fold_test_idxs in k_splits.split(mols):\n",
                "    k_train_indices.append(fold_train_idxs)\n",
                "    k_val_indices.append([])  # the folds carry no validation indices, so this slot stays empty\n",
                "    k_test_indices.append(fold_test_idxs)\n",
                "\n",
                "# `None` is passed for the validation indices, so the validation result is discarded\n",
                "k_train_data, _, k_test_data = data.split_data_by_indices(\n",
                "    all_data, k_train_indices, None, k_test_indices\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Get MoleculeDataset\n",
                "Recall that the data is in a list equal in length to the number of replicates, so we select the zero index of the list to get the first replicate."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 11,
            "metadata": {},
            "outputs": [],
            "source": [
                "featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()\n",
                "\n",
                "train_dset = data.MoleculeDataset(train_data[0], featurizer)\n",
                "scaler = train_dset.normalize_targets()\n",
                "\n",
                "val_dset = data.MoleculeDataset(val_data[0], featurizer)\n",
                "val_dset.normalize_targets(scaler)\n",
                "\n",
                "test_dset = data.MoleculeDataset(test_data[0], featurizer)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Get DataLoader"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 12,
            "metadata": {},
            "outputs": [],
            "source": [
                "train_loader = data.build_dataloader(train_dset, num_workers=num_workers)\n",
                "val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)\n",
                "test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Change Message-Passing Neural Network (MPNN) inputs here"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Message Passing\n",
                "A `Message passing` constructs molecular graphs using message passing to learn node-level hidden representations.\n",
                "\n",
                "Options are `mp = nn.BondMessagePassing()` or `mp = nn.AtomMessagePassing()`"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 13,
            "metadata": {},
            "outputs": [],
            "source": [
                "mp = nn.BondMessagePassing()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Aggregation\n",
                "An `Aggregation` is responsible for constructing a graph-level representation from the set of node-level representations after message passing.\n",
                "\n",
                "Available options can be found in `nn.agg.AggregationRegistry`, including\n",
                "- `agg = nn.MeanAggregation()`\n",
                "- `agg = nn.SumAggregation()`\n",
                "- `agg = nn.NormAggregation()`"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 14,
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "ClassRegistry {\n",
                        "    'mean': <class 'chemprop.nn.agg.MeanAggregation'>,\n",
                        "    'sum': <class 'chemprop.nn.agg.SumAggregation'>,\n",
                        "    'norm': <class 'chemprop.nn.agg.NormAggregation'>\n",
                        "}\n"
                    ]
                }
            ],
            "source": [
                "print(nn.agg.AggregationRegistry)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 15,
            "metadata": {},
            "outputs": [],
            "source": [
                "agg = nn.MeanAggregation()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Feed-Forward Network (FFN)\n",
                "\n",
                "An `FFN` takes the aggregated representations and makes target predictions.\n",
                "\n",
                "Available options can be found in `nn.PredictorRegistry`.\n",
                "\n",
                "For regression:\n",
                "- `ffn = nn.RegressionFFN()`\n",
                "- `ffn = nn.MveFFN()`\n",
                "- `ffn = nn.EvidentialFFN()`\n",
                "\n",
                "For classification:\n",
                "- `ffn = nn.BinaryClassificationFFN()`\n",
                "- `ffn = nn.BinaryDirichletFFN()`\n",
                "- `ffn = nn.MulticlassClassificationFFN()`\n",
                "- `ffn = nn.MulticlassDirichletFFN()`\n",
                "\n",
                "For spectral:\n",
                "- `ffn = nn.SpectralFFN()` # will be available in a future version"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 16,
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "ClassRegistry {\n",
                        "    'regression': <class 'chemprop.nn.predictors.RegressionFFN'>,\n",
                        "    'regression-mve': <class 'chemprop.nn.predictors.MveFFN'>,\n",
                        "    'regression-evidential': <class 'chemprop.nn.predictors.EvidentialFFN'>,\n",
                        "    'regression-quantile': <class 'chemprop.nn.predictors.QuantileFFN'>,\n",
                        "    'classification': <class 'chemprop.nn.predictors.BinaryClassificationFFN'>,\n",
                        "    'classification-dirichlet': <class 'chemprop.nn.predictors.BinaryDirichletFFN'>,\n",
                        "    'multiclass': <class 'chemprop.nn.predictors.MulticlassClassificationFFN'>,\n",
                        "    'multiclass-dirichlet': <class 'chemprop.nn.predictors.MulticlassDirichletFFN'>,\n",
                        "    'spectral': <class 'chemprop.nn.predictors.SpectralFFN'>\n",
                        "}\n"
                    ]
                }
            ],
            "source": [
                "print(nn.PredictorRegistry)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 17,
            "metadata": {},
            "outputs": [],
            "source": [
                "output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 18,
            "metadata": {},
            "outputs": [],
            "source": [
                "ffn = nn.RegressionFFN(output_transform=output_transform)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Batch Norm\n",
                "`Batch Norm` normalizes the outputs of the aggregation by re-centering and re-scaling.\n",
                "\n",
                "Set whether to use batch norm:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 19,
            "metadata": {},
            "outputs": [],
            "source": [
                "batch_norm = True"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Metrics\n",
                "`Metrics` are used to evaluate the performance of model predictions.\n",
                "\n",
                "Available options can be found in `nn.metrics.MetricRegistry`:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 20,
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "ClassRegistry {\n",
                        "    'mse': <class 'chemprop.nn.metrics.MSE'>,\n",
                        "    'mae': <class 'chemprop.nn.metrics.MAE'>,\n",
                        "    'rmse': <class 'chemprop.nn.metrics.RMSE'>,\n",
                        "    'bounded-mse': <class 'chemprop.nn.metrics.BoundedMSE'>,\n",
                        "    'bounded-mae': <class 'chemprop.nn.metrics.BoundedMAE'>,\n",
                        "    'bounded-rmse': <class 'chemprop.nn.metrics.BoundedRMSE'>,\n",
                        "    'r2': <class 'chemprop.nn.metrics.R2Score'>,\n",
                        "    'binary-mcc': <class 'chemprop.nn.metrics.BinaryMCCMetric'>,\n",
                        "    'multiclass-mcc': <class 'chemprop.nn.metrics.MulticlassMCCMetric'>,\n",
                        "    'roc': <class 'chemprop.nn.metrics.BinaryAUROC'>,\n",
                        "    'prc': <class 'chemprop.nn.metrics.BinaryAUPRC'>,\n",
                        "    'accuracy': <class 'chemprop.nn.metrics.BinaryAccuracy'>,\n",
                        "    'f1': <class 'chemprop.nn.metrics.BinaryF1Score'>\n",
                        "}\n"
                    ]
                }
            ],
            "source": [
                "print(nn.metrics.MetricRegistry)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "metric_list = [nn.metrics.RMSE(), nn.metrics.MAE()] # Only the first metric is used for training and early stopping"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Construct MPNN"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 22,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "MPNN(\n",
                            "  (message_passing): BondMessagePassing(\n",
                            "    (W_i): Linear(in_features=86, out_features=300, bias=False)\n",
                            "    (W_h): Linear(in_features=300, out_features=300, bias=False)\n",
                            "    (W_o): Linear(in_features=372, out_features=300, bias=True)\n",
                            "    (dropout): Dropout(p=0.0, inplace=False)\n",
                            "    (tau): ReLU()\n",
                            "    (V_d_transform): Identity()\n",
                            "    (graph_transform): Identity()\n",
                            "  )\n",
                            "  (agg): MeanAggregation()\n",
                            "  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
                            "  (predictor): RegressionFFN(\n",
                            "    (ffn): MLP(\n",
                            "      (0): Sequential(\n",
                            "        (0): Linear(in_features=300, out_features=300, bias=True)\n",
                            "      )\n",
                            "      (1): Sequential(\n",
                            "        (0): ReLU()\n",
                            "        (1): Dropout(p=0.0, inplace=False)\n",
                            "        (2): Linear(in_features=300, out_features=1, bias=True)\n",
                            "      )\n",
                            "    )\n",
                            "    (criterion): MSE(task_weights=[[1.0]])\n",
                            "    (output_transform): UnscaleTransform()\n",
                            "  )\n",
                            "  (X_d_transform): Identity()\n",
                            "  (metrics): ModuleList(\n",
                            "    (0): RMSE(task_weights=[[1.0]])\n",
                            "    (1): MAE(task_weights=[[1.0]])\n",
                            "    (2): MSE(task_weights=[[1.0]])\n",
                            "  )\n",
                            ")"
                        ]
                    },
                    "execution_count": 22,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)\n",
                "mpnn"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Set up trainer"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 23,
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "GPU available: False, used: False\n",
                        "TPU available: False, using: 0 TPU cores\n",
                        "HPU available: False, using: 0 HPUs\n"
                    ]
                }
            ],
            "source": [
                "# Checkpoint callback: keeps the checkpoint with the lowest validation loss plus the latest one\n",
                "checkpointing = ModelCheckpoint(\n",
                "    dirpath=\"checkpoints\",  # directory where model checkpoints will be saved\n",
                "    filename=\"best-{epoch}-{val_loss:.2f}\",  # filename template including epoch and validation loss\n",
                "    monitor=\"val_loss\",  # metric used to select the best checkpoint\n",
                "    mode=\"min\",  # a lower validation loss is better\n",
                "    save_last=True,  # also keep the most recent checkpoint, even if it is not the best\n",
                ")\n",
                "\n",
                "# Trainer that runs the fit loop with the checkpoint callback attached\n",
                "trainer = pl.Trainer(\n",
                "    max_epochs=20,  # number of epochs to train for\n",
                "    accelerator=\"auto\",\n",
                "    devices=1,\n",
                "    logger=False,\n",
                "    enable_progress_bar=True,\n",
                "    enable_checkpointing=True,  # `True` saves model checkpoints; they are written to the `checkpoints` folder\n",
                "    callbacks=[checkpointing],  # use the checkpoint callback configured above\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Start training"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 24,
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/knathan/chemprop/examples/checkpoints exists and is not empty.\n",
                        "Loading `train_dataloader` to estimate number of stepping batches.\n",
                        "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
                        "\n",
                        "  | Name            | Type               | Params | Mode \n",
                        "---------------------------------------------------------------\n",
                        "0 | message_passing | BondMessagePassing | 227 K  | train\n",
                        "1 | agg             | MeanAggregation    | 0      | train\n",
                        "2 | bn              | BatchNorm1d        | 600    | train\n",
                        "3 | predictor       | RegressionFFN      | 90.6 K | train\n",
                        "4 | X_d_transform   | Identity           | 0      | train\n",
                        "5 | metrics         | ModuleList         | 0      | train\n",
                        "---------------------------------------------------------------\n",
                        "318 K     Trainable params\n",
                        "0         Non-trainable params\n",
                        "318 K     Total params\n",
                        "1.276     Total estimated model params size (MB)\n",
                        "25        Modules in train mode\n",
                        "0         Modules in eval mode\n"
                    ]
                },
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "                                                                           "
                    ]
                },
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
                    ]
                },
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "Epoch 1:   0%|          | 0/2 [00:00<?, ?it/s, train_loss_step=0.880, val_loss=0.886, train_loss_epoch=1.010]        "
                    ]
                },
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "Epoch 19: 100%|██████████| 2/2 [00:00<00:00,  3.89it/s, train_loss_step=0.0757, val_loss=0.789, train_loss_epoch=0.066] "
                    ]
                },
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "`Trainer.fit` stopped: `max_epochs=20` reached.\n"
                    ]
                },
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "Epoch 19: 100%|██████████| 2/2 [00:00<00:00,  3.75it/s, train_loss_step=0.0757, val_loss=0.789, train_loss_epoch=0.066]\n"
                    ]
                }
            ],
            "source": [
                "trainer.fit(mpnn, train_loader, val_loader)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Test results"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 25,
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:145: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.\n",
                        "Restoring states from the checkpoint path at /home/knathan/chemprop/examples/checkpoints/best-epoch=16-val_loss=0.78.ckpt\n",
                        "Loaded model weights from the checkpoint at /home/knathan/chemprop/examples/checkpoints/best-epoch=16-val_loss=0.78.ckpt\n",
                        "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
                    ]
                },
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 19.89it/s]\n"
                    ]
                },
                {
                    "data": {
                        "text/html": [
                            "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
                            "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
                            "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
                            "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/mae          </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     0.643399715423584     </span>│\n",
                            "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/rmse         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9120855927467346     </span>│\n",
                            "└───────────────────────────┴───────────────────────────┘\n",
                            "</pre>\n"
                        ],
                        "text/plain": [
                            "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
                            "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
                            "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
                            "│\u001b[36m \u001b[0m\u001b[36m        test/mae         \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    0.643399715423584    \u001b[0m\u001b[35m \u001b[0m│\n",
                            "│\u001b[36m \u001b[0m\u001b[36m        test/rmse        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9120855927467346    \u001b[0m\u001b[35m \u001b[0m│\n",
                            "└───────────────────────────┴───────────────────────────┘\n"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                }
            ],
            "source": [
                "results = trainer.test(dataloaders=test_loader)"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "chemprop",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.11.8"
        },
        "orig_nbformat": 4
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
